{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "import numpy as np\n",
    "from sklearn.metrics import accuracy_score, roc_auc_score\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "from pytorch_tabular.utils import make_mixed_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "data, cat_col_names, num_col_names = make_mixed_dataset(\n",
    "    task=\"classification\", n_samples=10000, n_features=20, n_categories=4\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Importing the Library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from pytorch_tabular import TabularModel\n",
    "from pytorch_tabular.models import (\n",
    "    CategoryEmbeddingModelConfig,\n",
    ")\n",
    "from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig\n",
    "from pytorch_tabular.models.common.heads import LinearHeadConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train, test = train_test_split(data, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Cross Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "Collapsed": "false",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "data_config = DataConfig(\n",
    "    target=[\n",
    "        \"target\"\n",
    "    ],  # target should always be a list. Multi-targets are only supported for regression. Multi-Task Classification is not implemented\n",
    "    continuous_cols=num_col_names,\n",
    "    categorical_cols=cat_col_names,\n",
    ")\n",
    "trainer_config = TrainerConfig(\n",
    "    batch_size=1024,\n",
    "    max_epochs=100,\n",
    "    early_stopping=\"valid_loss\",  # Monitor valid_loss for early stopping\n",
    "    early_stopping_mode=\"min\",  # Set the mode as min because for val_loss, lower is better\n",
    "    early_stopping_patience=5,  # No. of epochs of degradation training will wait before terminating\n",
    "    checkpoints=\"valid_loss\",  # Save best checkpoint monitoring val_loss\n",
    "    load_best=True,  # After training, load the best checkpoint\n",
    "    progress_bar=\"none\",  # Turning off Progress bar\n",
    "    trainer_kwargs=dict(enable_model_summary=False),  # Turning off model summary\n",
    ")\n",
    "optimizer_config = OptimizerConfig()\n",
    "\n",
    "head_config = LinearHeadConfig(\n",
    "    layers=\"\",\n",
    "    dropout=0.1,\n",
    "    initialization=(  # No additional layer in head, just a mapping layer to output_dim\n",
    "        \"kaiming\"\n",
    "    ),\n",
    ").__dict__  # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)\n",
    "\n",
    "model_config = CategoryEmbeddingModelConfig(\n",
    "    task=\"classification\",\n",
    "    layers=\"1024-512-512\",  # Number of nodes in each layer\n",
    "    activation=\"LeakyReLU\",  # Activation between each layers\n",
    "    learning_rate=1e-3,\n",
    "    head=\"LinearHead\",  # Linear Head\n",
    "    head_config=head_config,  # Linear Head Config\n",
    ")\n",
    "\n",
    "tabular_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    "    verbose=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using High-Level API\n",
    "\n",
    "We can use the high-level method `cross_validate` in `TabularModel`.\n",
    "\n",
    "The arguments are as follows:\n",
    "- `cv` can either be an integer or a `KFold` object. If it is an integer, it will be treated as the number of folds in a KFold. For classification problems, it will be a StratifiedKFold. If it is a `KFold` object, it will be used as is.\n",
    "- `metric` is the metric to be used for evaluation. It can either be a string (name of the metric) or a callable. If it is a callable, it should take in two arguments, the targets and the predictions, in that order (i.e. `metric(y_true, y_pred)`). The predictions should be the dataframe output from the `model.predict` and the target can be a series or an array.\n",
    "- `train` is the training dataset. \n",
    "- `return_oof` is a boolean. If set to True, it will return the out-of-fold predictions for the training dataset. This is useful for stacking models.\n",
    "- `reset_datamodule` is a boolean. If set to True, it will reset the datamodule after each fold, and is the right way of doing cross validation. If set to False, it will not reset the datamodule and will be faster, but will have a small amount of data leakage. This is useful when working with huge datasets and you want to save time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">13:08:18</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">468</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1925</span><span style=\"font-weight: bold\">}</span> - INFO - Running Fold <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>/<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>                           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m31\u001b[0m \u001b[1;92m13:08:18\u001b[0m,\u001b[1;36m468\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1925\u001b[0m\u001b[1m}\u001b[0m - INFO - Running Fold \u001b[1;36m1\u001b[0m/\u001b[1;36m2\u001b[0m                           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "You are using a CUDA device ('NVIDIA GeForce RTX 3060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">13:08:22</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">376</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1952</span><span style=\"font-weight: bold\">}</span> - INFO - Fold <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span>/<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> score: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.908</span>                      \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m31\u001b[0m \u001b[1;92m13:08:22\u001b[0m,\u001b[1;36m376\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1952\u001b[0m\u001b[1m}\u001b[0m - INFO - Fold \u001b[1;36m1\u001b[0m/\u001b[1;36m2\u001b[0m score: \u001b[1;36m0.908\u001b[0m                      \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">13:08:22</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">383</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1925</span><span style=\"font-weight: bold\">}</span> - INFO - Running Fold <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>/<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>                           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m31\u001b[0m \u001b[1;92m13:08:22\u001b[0m,\u001b[1;36m383\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1925\u001b[0m\u001b[1m}\u001b[0m - INFO - Running Fold \u001b[1;36m2\u001b[0m/\u001b[1;36m2\u001b[0m                           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2023</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">12</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">31</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">13:08:24</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">704</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1952</span><span style=\"font-weight: bold\">}</span> - INFO - Fold <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span>/<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> score: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9517333333333333</span>         \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2023\u001b[0m-\u001b[1;36m12\u001b[0m-\u001b[1;36m31\u001b[0m \u001b[1;92m13:08:24\u001b[0m,\u001b[1;36m704\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1952\u001b[0m\u001b[1m}\u001b[0m - INFO - Fold \u001b[1;36m2\u001b[0m/\u001b[1;36m2\u001b[0m score: \u001b[1;36m0.9517333333333333\u001b[0m         \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# cross validation loop usnig sklearn\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "\n",
    "kf = KFold(n_splits=5, shuffle=True, random_state=42)\n",
    "\n",
    "\n",
    "def _accuracy(y_true, y_pred):\n",
    "    return accuracy_score(y_true, y_pred[\"prediction\"].values)\n",
    "\n",
    "\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter(\"ignore\")\n",
    "    cv_scores, oof_predictions = tabular_model.cross_validate(\n",
    "        cv=2, train=train, metric=_accuracy, return_oof=True, reset_datamodule=False\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "KFold Mean: 0.9298666666666666 | KFold SD: 0.021866666666666645\n"
     ]
    }
   ],
   "source": [
    "print(f\"KFold Mean: {np.mean(cv_scores)} | KFold SD: {np.std(cv_scores)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Using Low-Level API    \n",
    "\n",
    "Sometimes, we might want to do something more than just a plain, vanilla cross validation. For example, we might want to do a cross validation with multiple metrics, or we might want to do a cross validation with a custom metric which relies on something other than the target and predictions. In such cases, we can use the low-level API."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from rich import print"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">Fold:</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0</span> | <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">Accuracy:</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9293333333333333</span> | <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">AUC:</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9807391279599271</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;31mFold:\u001b[0m \u001b[1;36m0\u001b[0m | \u001b[1;32mAccuracy:\u001b[0m \u001b[1;36m0.9293333333333333\u001b[0m | \u001b[1;32mAUC:\u001b[0m \u001b[1;36m0.9807391279599271\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">Fold:</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1</span> | <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">Accuracy:</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9146666666666666</span> | <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">AUC:</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9736274684219891</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;31mFold:\u001b[0m \u001b[1;36m1\u001b[0m | \u001b[1;32mAccuracy:\u001b[0m \u001b[1;36m0.9146666666666666\u001b[0m | \u001b[1;32mAUC:\u001b[0m \u001b[1;36m0.9736274684219891\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">Fold:</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2</span> | <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">Accuracy:</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.924</span> | <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">AUC:</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9730588808512757</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;31mFold:\u001b[0m \u001b[1;36m2\u001b[0m | \u001b[1;32mAccuracy:\u001b[0m \u001b[1;36m0.924\u001b[0m | \u001b[1;32mAUC:\u001b[0m \u001b[1;36m0.9730588808512757\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">Fold:</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">3</span> | <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">Accuracy:</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.922</span> | <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">AUC:</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9757440627005844</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;31mFold:\u001b[0m \u001b[1;36m3\u001b[0m | \u001b[1;32mAccuracy:\u001b[0m \u001b[1;36m0.922\u001b[0m | \u001b[1;32mAUC:\u001b[0m \u001b[1;36m0.9757440627005844\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #800000; text-decoration-color: #800000; font-weight: bold\">Fold:</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4</span> | <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">Accuracy:</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9166666666666666</span> | <span style=\"color: #008000; text-decoration-color: #008000; font-weight: bold\">AUC:</span> <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9743804540010267</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;31mFold:\u001b[0m \u001b[1;36m4\u001b[0m | \u001b[1;32mAccuracy:\u001b[0m \u001b[1;36m0.9166666666666666\u001b[0m | \u001b[1;32mAUC:\u001b[0m \u001b[1;36m0.9743804540010267\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def _accuracy(y_true, y_pred):\n",
    "    return accuracy_score(y_true, y_pred[\"prediction\"].values)\n",
    "\n",
    "\n",
    "def _roc_auc_score(y_true, y_pred):\n",
    "    return roc_auc_score(y_true, y_pred[\"class_1_probability\"].values)\n",
    "\n",
    "\n",
    "kf = KFold(n_splits=5, shuffle=True, random_state=42)\n",
    "# Initialize the tabular model onece\n",
    "tabular_model = TabularModel(\n",
    "    data_config=data_config,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    "    verbose=False,\n",
    ")\n",
    "acc_metrics = []\n",
    "roc_metrics = []\n",
    "preds = []\n",
    "datamodule = None\n",
    "model = None\n",
    "with warnings.catch_warnings():\n",
    "    warnings.simplefilter(\"ignore\")\n",
    "    for fold, (train_idx, val_idx) in enumerate(kf.split(train)):\n",
    "        train_fold = train.iloc[train_idx]\n",
    "        val_fold = train.iloc[val_idx]\n",
    "        if datamodule is None:\n",
    "            # Initialize datamodule and model in the first fold\n",
    "            # uses train data from this fold to fit all transformers\n",
    "            datamodule = tabular_model.prepare_dataloader(\n",
    "                train=train_fold, validation=val_fold, seed=42\n",
    "            )\n",
    "            model = tabular_model.prepare_model(datamodule)\n",
    "        else:\n",
    "            # Creates a copy of the datamodule with same transformers but different train and validation data\n",
    "            datamodule = datamodule.copy(train=train_fold, validation=val_fold)\n",
    "        # Train the model\n",
    "        tabular_model.train(model, datamodule)\n",
    "        pred_df = tabular_model.predict(val_fold)\n",
    "        acc_metrics.append(_accuracy(val_fold[\"target\"], pred_df))\n",
    "        roc_metrics.append(_roc_auc_score(val_fold[\"target\"], pred_df))\n",
    "        print(\n",
    "            f\"[bold red]Fold:[/bold red] {fold} | [bold green]Accuracy:[/bold green]\"\n",
    "            f\" {acc_metrics[-1]} | [bold green]AUC:[/bold green] {roc_metrics[-1]}\"\n",
    "        )\n",
    "        # Reset the trained weights before next fold\n",
    "        tabular_model.model.reset_weights()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">KFold Accuracy Mean: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9213333333333333</span> | KFold Accuracy SD: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.005249338582674566</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "KFold Accuracy Mean: \u001b[1;36m0.9213333333333333\u001b[0m | KFold Accuracy SD: \u001b[1;36m0.005249338582674566\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">KFold AUC Mean: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.9755099987869607</span> | KFold AUC SD: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.002765008099674828</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "KFold AUC Mean: \u001b[1;36m0.9755099987869607\u001b[0m | KFold AUC SD: \u001b[1;36m0.002765008099674828\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "print(\n",
    "    f\"KFold Accuracy Mean: {np.mean(acc_metrics)} | KFold Accuracy SD:\"\n",
    "    f\" {np.std(acc_metrics)}\"\n",
    ")\n",
    "print(f\"KFold AUC Mean: {np.mean(roc_metrics)} | KFold AUC SD: {np.std(roc_metrics)}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  },
  "vscode": {
   "interpreter": {
    "hash": "ad8d5d2789703c7b1c2f7bfaada1cbd3aa0ac53e2e4e1cae5da195f5520da229"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
